{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TF-TRT Keras Retinanet Detection Example:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we are going to optimize a Retinanet detection model from the official Keras examples! \n",
    "\n",
    "You can find the implementation here: https://keras.io/examples/vision/retinanet/\n",
    "\n",
    "In general, detection models can be tricky to optimize because they tend to require a lot of custom logic for sub-tasks such as region proposal, output decoding, or non-maximum suppression. This makes them a good demonstration of TF-TRT's capabilities - It does a great job of optimizing a large part of the network while leaving the custom logic untouched."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make sure our GPUs are properly configured and visible:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Jan 29 23:17:01 2021       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.1     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |\n",
      "| N/A   42C    P0    37W / 300W |    125MiB / 16155MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    38W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  Tesla V100-DGXS...  On   | 00000000:0E:00.0 Off |                    0 |\n",
      "| N/A   42C    P0    38W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  Tesla V100-DGXS...  On   | 00000000:0F:00.0 Off |                    0 |\n",
      "| N/A   43C    P0    37W / 300W |      6MiB / 16158MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will also need matplotlib to run the model. If you do not have it, run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (3.3.4)\n",
      "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.17.3)\n",
      "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (0.10.0)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (1.3.1)\n",
      "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.4.7)\n",
      "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (2.8.1)\n",
      "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.6/dist-packages (from matplotlib) (8.1.0)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from cycler>=0.10->matplotlib) (1.15.0)\n",
      "\u001b[33mWARNING: You are using pip version 20.2.3; however, version 21.0 is available.\n",
      "You should consider upgrading via the '/usr/bin/python -m pip install --upgrade pip' command.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remember to sucessfully deploy a TensorRT model, you have to make __five key decisions__:\n",
    "\n",
    "1. __What format should I save my model in?__\n",
    "2. __What batch size(s) am I running inference at?__\n",
    "3. __What precision am I running inference at?__\n",
    "4. __What TensorRT path am I using to convert my model?__\n",
    "5. __What runtime am I targeting?__\n",
    "\n",
    "Let's give it a shot!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. What format should I save my model in?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will work with one of the Keras example RetinaNet implementations. We can download the implementation code for the specific version of it required here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2021-01-29 23:17:05--  https://raw.githubusercontent.com/keras-team/keras-io/cd6201c1bfa37625f503f51e8fd3c572666770e4/examples/vision/retinanet.py\n",
      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.40.133\n",
      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.40.133|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 35046 (34K) [text/plain]\n",
      "Saving to: ‘retinanet.py’\n",
      "\n",
      "retinanet.py        100%[===================>]  34.22K  --.-KB/s    in 0.002s  \n",
      "\n",
      "2021-01-29 23:17:05 (20.1 MB/s) - ‘retinanet.py’ saved [35046/35046]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget -O retinanet.py https://raw.githubusercontent.com/keras-team/keras-io/cd6201c1bfa37625f503f51e8fd3c572666770e4/examples/vision/retinanet.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The code has some unnecessary setup steps, so we will pull out just the model implementation itself using sed (you can check the end result in the [retinanet_model.py](./retinanet_model.py) file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "!sed -n '1,40 p; 71,820 p' retinanet.py > retinanet_model.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p tmp_savedmodels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We perform some imports and setup:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_size = (224, 224)\n",
    "num_classes = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can import our necessary RetinaNet functions from the example and initialize our detection model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5\n",
      "94773248/94765736 [==============================] - 2s 0us/step\n"
     ]
    }
   ],
   "source": [
    "from retinanet_model import RetinaNet, DecodePredictions, get_backbone\n",
    "\n",
    "resnet50_backbone = get_backbone()\n",
    "model = RetinaNet(num_classes, resnet50_backbone)\n",
    "\n",
    "image = tf.keras.Input(shape=[None, None, 3], name=\"image\")\n",
    "predictions = model(image, training=False)\n",
    "detections = DecodePredictions(confidence_threshold=0.5)(image, predictions)\n",
    "inference_model = tf.keras.Model(inputs=image, outputs=detections)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we save our model in SavedModel format!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Model.state_updates (from tensorflow.python.keras.engine.training) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/tracking/tracking.py:111: Layer.updates (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This property should not be used in TensorFlow 2.0, as updates are applied automatically.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/detect_model/assets\n"
     ]
    }
   ],
   "source": [
    "model_dir = \"tmp_savedmodels/detect_model\"\n",
    "model.save(model_dir) "
   ]
  },
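  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to double-check what was exported, TensorFlow ships a `saved_model_cli` tool that can inspect a SavedModel's signatures (output omitted here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!saved_model_cli show --dir tmp_savedmodels/detect_model --tag_set serve --signature_def serving_default"
   ]
  },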
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. What batch size(s) am I running inference at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will create a dummy batch of size 32:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "dummy_input = np.zeros((32, img_size[0], img_size[1], 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CombinedNonMaxSuppression(nmsed_boxes=array([[[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]],\n",
       "\n",
       "       [[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]],\n",
       "\n",
       "       [[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]],\n",
       "\n",
       "       ...,\n",
       "\n",
       "       [[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]],\n",
       "\n",
       "       [[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]],\n",
       "\n",
       "       [[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]]], dtype=float32), nmsed_scores=array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), nmsed_classes=array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       ...,\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.],\n",
       "       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), valid_detections=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "       0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inference_model.predict(dummy_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. What precision am I running inference at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will stick with the same FP32 precision used during training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "PRECISION = \"FP32\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. What TensorRT path am I using to convert my model?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will use our example TF-TRT based ModelOptimizer wrapper:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from helper import ModelOptimizer\n",
    "\n",
    "model_opt = ModelOptimizer(model_dir)"
   ]
  },
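  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wrapper hides the details, but a TF-TRT conversion is ultimately driven by TensorFlow's `TrtGraphConverterV2` API. Here is a minimal sketch of an equivalent manual conversion (an illustration of the underlying API, not necessarily exactly what `helper.py` does):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal manual TF-TRT conversion (sketch; the helper wraps something similar)\n",
    "from tensorflow.python.compiler.tensorrt import trt_convert as trt\n",
    "\n",
    "params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(\n",
    "    precision_mode=trt.TrtPrecisionMode.FP32)  # FP16/INT8 are also available\n",
    "converter = trt.TrtGraphConverterV2(input_saved_model_dir=model_dir,\n",
    "                                    conversion_params=params)\n",
    "converter.convert()  # replace supported subgraphs with TRT engine ops\n",
    "converter.save(model_dir + '_FP32_manual')  # write the converted SavedModel"
   ]
  },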
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Convert to our target precision, saving the result in a new SavedModel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Linked TensorRT version: (7, 2, 1)\n",
      "INFO:tensorflow:Loaded TensorRT version: (7, 2, 2)\n",
      "INFO:tensorflow:Loaded TensorRT 7.2.2 and linked TensorFlow against TensorRT 7.2.1. This is supported because TensorRT  minor/patch upgrades are backward compatible\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_2 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_0 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_1 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Could not find TRTEngineOp_0_3 in TF-TRT cache. This can happen if build() is not called, which means TensorRT engines will be built and cached at runtime.\n",
      "INFO:tensorflow:Assets written to: tmp_savedmodels/detect_model_FP32/assets\n",
      "conversion complete! prediction shape: (32, 9441, 14)\n"
     ]
    }
   ],
   "source": [
    "opt_trt = model_opt.convert(model_dir+'_'+PRECISION, precision=PRECISION)\n",
    "print(\"conversion complete! prediction shape:\", opt_trt.predict(dummy_input).shape)"
   ]
  },
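  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the INFO messages above: because `build()` was not called, the TensorRT engines are built and cached at runtime, the first time each input shape is seen. If the inference shape is known ahead of time, the converter API lets you pre-build the engines before saving. A sketch, reusing the `converter` from the manual example above (`build()` belongs between `convert()` and `save()`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: pre-build engines for a known input shape so the first inference\n",
    "# call does not pay the engine-building cost (sketch only)\n",
    "def engine_input_fn():\n",
    "    # Yield one batch shaped like the real inference input\n",
    "    yield (tf.zeros((32, img_size[0], img_size[1], 3), dtype=tf.float32),)\n",
    "\n",
    "converter.build(input_fn=engine_input_fn)\n",
    "converter.save(model_dir + '_' + PRECISION + '_prebuilt')"
   ]
  },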
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. What TensorRT runtime am I targeting?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will stick to our TF-TRT/Tensorflow runtime:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up...\n",
      "(32, 9441, 14)\n",
      "(32, 9441, 14)\n",
      "Done warming up!\n"
     ]
    }
   ],
   "source": [
    "print(\"Warming up...\")\n",
    "\n",
    "print(model.predict(dummy_input).shape)\n",
    "print(opt_trt.predict(dummy_input).shape)\n",
    "\n",
    "print(\"Done warming up!\")"
   ]
  },
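  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The converted model is also just a normal SavedModel, so it can be loaded and served with the plain TensorFlow runtime, independent of Keras. A sketch (assuming the usual `serving_default` signature key; the exact input name comes from the signature itself):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loading the converted SavedModel directly in the TensorFlow runtime (sketch)\n",
    "loaded = tf.saved_model.load(model_dir + '_' + PRECISION)\n",
    "infer = loaded.signatures['serving_default']  # assumed default signature key\n",
    "\n",
    "# Signatures take keyword arguments named after their inputs; look the name up\n",
    "# rather than hard-coding it\n",
    "input_name = list(infer.structured_input_signature[1].keys())[0]\n",
    "batch = tf.constant(dummy_input, dtype=tf.float32)\n",
    "preds = infer(**{input_name: batch})  # dict mapping output names to tensors"
   ]
  },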
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Performance Comparisons:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "109 ms ± 5.53 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "\n",
    "preds = model.predict(dummy_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45.1 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "\n",
    "preds = opt_trt.predict(dummy_input)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
